package dict

import (
	"github.com/DiracLee/dires-go/utils"
	"math"
	"math/rand"
	"sync"
	"sync/atomic"
)

// concurrentDict is a string-keyed dictionary whose entries are spread over
// several shards, each protected by its own lock.
type concurrentDict struct {
	// count is the number of stored entries. It is only accessed through
	// sync/atomic, and is deliberately the first field: the 64-bit atomic
	// functions require 64-bit alignment, which on 32-bit platforms is only
	// guaranteed for the first word of an allocated struct.
	count int64

	// table holds the shards; utils.GetIndex maps a key to its slot.
	table []*shard

	// tableSize is the number of shards the dict was created with.
	tableSize uint32
}

// NewConcurrent creates a Dict backed by a sharded concurrentDict.
// tableSize is the requested shard count and is handed to
// makeConcurrentDict unchanged.
func NewConcurrent(tableSize uint32) Dict {
	var d Dict = makeConcurrentDict(tableSize)
	return d
}

// makeConcurrentDict builds an empty concurrentDict. The requested tableSize
// is first normalized by computeTableSize, and one shard with an empty map is
// allocated for every slot of the resulting table.
func makeConcurrentDict(tableSize uint32) *concurrentDict {
	size := computeTableSize(tableSize)
	shards := make([]*shard, size)
	for i := 0; i < len(shards); i++ {
		// The zero value of the shard's RWMutex is ready to use.
		shards[i] = &shard{m: make(map[string]interface{})}
	}
	return &concurrentDict{table: shards, tableSize: size}
}

// Len reports how many key/value pairs the dict currently holds, read
// atomically from the shared counter.
func (dict *concurrentDict) Len() int {
	n := atomic.LoadInt64(&dict.count)
	return int(n)
}

// Get looks up `key` under the owning shard's read lock. It returns the
// stored value and true when the key is present, or nil and false otherwise.
func (dict *concurrentDict) Get(key string) (val interface{}, exists bool) {
	s := dict.getShard(key)
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	found, ok := s.m[key]
	return found, ok
}

// PutOrSet stores `val` under `key` unconditionally. It returns 1 when the
// key was newly inserted and 0 when an existing value was overwritten.
func (dict *concurrentDict) PutOrSet(key string, val interface{}) (result int) {
	s := dict.getShard(key)
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if _, found := s.m[key]; found {
		s.m[key] = val
		return 0
	}
	s.m[key] = val
	// Only a genuinely new key grows the entry count.
	atomic.AddInt64(&dict.count, 1)
	return 1
}

// PutIfNotExists inserts `val` under `key` only when the key is absent.
// It returns 1 if the pair was inserted and 0 if the key already existed,
// in which case the stored value is left untouched.
func (dict *concurrentDict) PutIfNotExists(key string, val interface{}) (result int) {
	s := dict.getShard(key)
	s.mutex.Lock()
	defer s.mutex.Unlock()

	_, found := s.m[key]
	if !found {
		s.m[key] = val
		atomic.AddInt64(&dict.count, 1)
		result = 1
	}
	return result
}

// SetIfExists replaces the value stored under `key` only when the key is
// already present. It returns 1 if the value was replaced and 0 if the key
// was missing; the entry count never changes.
func (dict *concurrentDict) SetIfExists(key string, val interface{}) (result int) {
	s := dict.getShard(key)
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if _, found := s.m[key]; found {
		s.m[key] = val
		return 1
	}
	return 0
}

// remove deletes a single `key` under the owning shard's write lock.
// It returns 1 if the key was present (and decrements the entry count),
// or 0 if there was nothing to delete.
func (dict *concurrentDict) remove(key string) (result int) {
	s := dict.getShard(key)
	s.mutex.Lock()
	defer s.mutex.Unlock()

	_, found := s.m[key]
	if found {
		delete(s.m, key)
		atomic.AddInt64(&dict.count, -1)
		return 1
	}
	return 0
}

// Remove deletes every key in `keys` and returns how many of them were
// actually present. Each key is removed independently, so a key listed
// twice is counted at most once.
func (dict *concurrentDict) Remove(keys ...string) (result int) {
	removed := 0
	for i := 0; i < len(keys); i++ {
		removed += dict.remove(keys[i])
	}
	return removed
}

// ForEach feeds every key/value pair to `consumer`, shard by shard, and stops
// as soon as a call returns false. It returns true only if every call
// returned true.
//
// `consumer` runs while the current shard's read lock is held; the lock is
// released through defer, so it is dropped even if `consumer` panics.
func (dict *concurrentDict) ForEach(consumer Consumer) bool {
	for _, s := range dict.table {
		// Wrap each shard in its own function so the deferred unlock fires
		// before the next shard is visited.
		keepGoing := func() bool {
			s.mutex.RLock()
			defer s.mutex.RUnlock()
			for key, value := range s.m {
				if !consumer(key, value) {
					return false
				}
			}
			return true
		}()
		if !keepGoing {
			return false
		}
	}
	return true
}

// Keys collects every key in the dict by walking it with ForEach.
// The result is pre-sized from Len and is never nil.
func (dict *concurrentDict) Keys() []string {
	collected := make([]string, 0, dict.Len())
	collect := func(key string, _ interface{}) bool {
		collected = append(collected, key)
		return true
	}
	dict.ForEach(collect)
	return collected
}

// RandomKeys returns `limit` keys sampled with replacement, so the same key
// may appear more than once. Each sample picks a random shard and takes one
// key from it; empty shards are skipped and retried.
//
// A non-positive `limit` yields an empty slice. If the dict holds no entries
// (or becomes empty while sampling) the keys gathered so far are returned
// instead of retrying forever, so fewer than `limit` keys may come back.
//
// NOTE(review): randomKey reports an empty shard as "", so a stored
// empty-string key is indistinguishable from "no key" and is never returned;
// a dict whose only key is "" therefore still cannot terminate here.
func (dict *concurrentDict) RandomKeys(limit int) []string {
	if limit <= 0 {
		return []string{}
	}
	// Exactly `limit` keys are produced, so size the slice by limit rather
	// than by the dict's current length.
	keys := make([]string, 0, limit)
	for len(keys) < limit {
		key := dict.randomKey()
		if key == "" {
			if dict.Len() == 0 {
				// Nothing left to sample; retrying would spin forever.
				break
			}
			continue
		}
		keys = append(keys, key)
	}
	return keys
}

// RandomDistinctKeys returns at most `limit` distinct keys chosen by repeated
// random sampling.
//
// When `limit` is at least the current number of entries, every key is
// wanted, so Keys is returned directly instead of sampling until each key
// happens to be drawn. A non-positive `limit` yields an empty slice.
// Sampling also stops once as many distinct keys were collected as the dict
// currently holds, so entries removed concurrently cannot make the loop spin
// forever; in that case fewer than `limit` keys are returned.
func (dict *concurrentDict) RandomDistinctKeys(limit int) []string {
	if limit <= 0 {
		return []string{}
	}
	if limit >= dict.Len() {
		return dict.Keys()
	}

	keySet := make(map[string]struct{}, limit)
	for len(keySet) < limit && len(keySet) < dict.Len() {
		key := dict.randomKey()
		if key == "" {
			// The sampled shard was empty; try another one.
			continue
		}
		keySet[key] = struct{}{}
	}

	keys := make([]string, 0, len(keySet))
	for key := range keySet {
		keys = append(keys, key)
	}
	return keys
}

// Clear removes every key/value pair from the dict.
//
// Each shard is emptied under its own write lock and the entry counter is
// adjusted atomically by the number of entries dropped. The dict's table and
// size fields are never reassigned, so Clear does not race with methods that
// read them concurrently. Shards are cleared one after another, so an entry
// written to an already-cleared shard while Clear is still running survives.
func (dict *concurrentDict) Clear() {
	for _, s := range dict.table {
		s.mutex.Lock()
		removed := len(s.m)
		s.m = make(map[string]interface{})
		// Adjust the counter while still holding the lock, matching how the
		// other mutating methods update it.
		atomic.AddInt64(&dict.count, -int64(removed))
		s.mutex.Unlock()
	}
}

// getShard returns the shard responsible for `key`; the slot is chosen by
// utils.GetIndex from the table length and the key.
func (dict *concurrentDict) getShard(key string) *shard {
	return dict.table[utils.GetIndex(len(dict.table), key)]
}

// randomKey picks a shard at random and returns one key from it, or "" when
// that shard is empty.
func (dict *concurrentDict) randomKey() string {
	slot := rand.Uint32() % dict.tableSize
	return dict.table[slot].randomKey()
}

// computeTableSize normalizes a requested shard count: values up to 16 become
// 16, anything larger is rounded up to the next power of two.
//
// Requests above 1<<31 have no larger power of two that fits in a uint32, so
// they are clamped to 1<<31 instead of wrapping around to 0.
func computeTableSize(tableSize uint32) uint32 {
	const (
		minTableSize = 16
		maxTableSize = uint32(1) << 31
	)
	if tableSize <= minTableSize {
		return minTableSize
	}
	// Smear the highest set bit of tableSize-1 into every lower position,
	// producing 2^k - 1.
	n := tableSize - 1
	n |= n >> 1
	n |= n >> 2
	n |= n >> 4
	n |= n >> 8
	n |= n >> 16
	if n == math.MaxUint32 {
		// n + 1 would overflow to 0; use the largest representable power
		// of two instead.
		return maxTableSize
	}
	return n + 1
}

// shard is one partition of a concurrentDict: a plain map together with the
// read/write lock that the dict's methods take before touching it.
type shard struct {
	m     map[string]interface{} // entries stored in this shard
	mutex sync.RWMutex           // guards m
}

// randomKey returns the first key produced by ranging over the shard's map
// while holding the read lock, or "" when the shard holds no entries.
func (s *shard) randomKey() string {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	var picked string
	for k := range s.m {
		picked = k
		break
	}
	return picked
}
